from baselines.load_data import load_mnist_1d, Bandit_multi,load_mnist_adv,synthetic
from EENet import EE_Net
import numpy as np
import os


# Per-dataset learning rates, as (lr_1, lr_2, lr_3):
#   lr_1 -> exploitation network, lr_2 -> exploration network, lr_3 -> decision maker.
LEARNING_RATES = {
    'covertype': (0.01, 0.01, 0.005),
    'MagicTelescope': (0.005, 0.005, 0.01),
    'shuttle': (0.005, 0.005, 0.005),
    'mushroom': (0.001, 0.01, 0.005),
    'fashion': (0.01, 0.01, 0.001),
    'Plants': (0.001, 0.001, 0.005),
}


def get_learning_rates(d):
    """Return the ``(lr_1, lr_2, lr_3)`` tuple configured for dataset *d*.

    Raises:
        ValueError: if *d* has no entry in ``LEARNING_RATES``. Failing loudly
            avoids silently reusing another dataset's learning rates.
    """
    try:
        return LEARNING_RATES[d]
    except KeyError:
        raise ValueError(
            'no learning rates configured for dataset {!r}; known: {}'.format(
                d, sorted(LEARNING_RATES))) from None


def load_bandit(d):
    """Build the bandit environment for dataset name *d*."""
    if d == 'mnist':
        return load_mnist_adv()
    if d in ('cos', 'square', 'quad'):
        return synthetic(d)
    return Bandit_multi(d)


def error_rates(error, count):
    """Return per-arm ``error / count``, using 0 for arms with ``count == 0``.

    A plain ``error / count`` emits divide-by-zero warnings and yields NaN for
    unobserved arms; ``np.argmax`` would then select the first NaN arm by accident.
    """
    error = np.asarray(error, dtype=float)
    count = np.asarray(count, dtype=float)
    return np.divide(error, count, out=np.zeros_like(error), where=count > 0)


if __name__ == '__main__':
    # Datasets to evaluate; each one needs an entry in LEARNING_RATES.
    dataset = ['covertype', 'MagicTelescope', 'shuttle', 'mushroom', 'fashion', 'Plants']

    running_times = 20  # independent runs per dataset
    n_rounds = 5000     # rounds per run
    block = 500         # warm-up length, and period for re-picking the arm k fed to b.step

    out_dir = os.path.join(os.getcwd(), 'results', 'iclr_results', 'eenet')
    # np.save does not create directories; create it now rather than crash
    # after all runs have finished.
    os.makedirs(out_dir, exist_ok=True)

    for d in dataset:
        # Resolve before loading data so an unconfigured dataset fails fast.
        lr_1, lr_2, lr_3 = get_learning_rates(d)
        b = load_bandit(d)
        regrets_all = []
        for _ in range(running_times):
            regrets = []
            sum_regret = 0
            ee_net = EE_Net(b.dim, b.n_arm, pool_step_size=50, lr_1=lr_1, lr_2=lr_2,
                            lr_3=lr_3, hidden=100, neural_decision_maker=False)
            # Cumulative per-arm counts, indexed by the `arm` value returned by b.step.
            error = np.zeros(b.n_arm)
            count = np.zeros(b.n_arm)
            for t in range(n_rounds):
                # Draw an input sample.
                if t < block:
                    # Warm-up: step(-1); presumably an unrestricted draw -- confirm in Bandit_multi.
                    context, rwd, arm = b.step(-1)
                else:
                    if t % block == 0:
                        # Every `block` rounds, re-pick the arm with the highest observed error rate.
                        rates = error_rates(error, count)
                        k = np.argmax(rates)
                        print(error, count, rates, k)
                    context, rwd, arm = b.step(k)
                arm_select = ee_net.predict(context, t)

                reward = rwd[arm_select]
                regret = np.max(rwd) - reward
                count[arm] += 1
                if reward == 0:
                    error[arm] += 1
                ee_net.update(context, reward, t)

                sum_regret += regret
                # Train every 10 rounds for the first 1000 rounds, then every 100.
                # t == 0 always trains, so the losses exist before the first print.
                train_every = 10 if t < 1000 else 100
                if t % train_every == 0:
                    loss_1, loss_2, loss_3 = ee_net.train(t)

                regrets.append(sum_regret)
                if t % 50 == 0:
                    print('round:{}, regret: {:},  average_regret: {:.3f}, loss_1:{:.4f}, loss_2:{:.4f}, loss_3:{:.4f}'.format(t,sum_regret, sum_regret/(t+1), loss_1, loss_2, loss_3))
            print(' regret: {:},  average_regret: {:.2f}'.format(sum_regret, sum_regret/(t+1)))
            regrets_all.append(regrets)
        np.save(os.path.join(out_dir, 'eenet_results_{}.npy'.format(d)), regrets_all)
